Deep Q-Networks (DQN) for Discrete Action Spaces — Low-Level PyTorch#
DQN combines Q-learning with a neural network to approximate an action-value function \(Q_\theta(s, a)\) when actions are discrete.
What you’ll learn#
the Bellman expectation and optimality equations (precisely, with LaTeX)
why experience replay stabilizes learning, and how to implement it
a minimal DQN in PyTorch: replay buffer, target network, \(\epsilon\)-greedy exploration
Plotly diagnostics: reward per episode, TD loss, and learned \(Q\)-values
a Stable-Baselines3 DQN reference + hyperparameter explanations (end)
Notebook roadmap#
RL notation + Bellman equations
DQN targets + loss
Experience replay (precise)
Target network updates
Low-level PyTorch implementation (from scratch)
Train on a tiny toy environment (no extra RL dependencies)
Plotly diagnostics: reward per episode, loss, and \(Q\)-values
Stable-Baselines3 DQN reference + hyperparameters (end)
import math
import random
from dataclasses import dataclass
import numpy as np
import plotly.express as px
import plotly.graph_objects as go
import os
import plotly.io as pio
import torch
import torch.nn as nn
import torch.nn.functional as F
pio.templates.default = "plotly_white"
pio.renderers.default = os.environ.get("PLOTLY_RENDERER", "notebook")
SEED = 7
random.seed(SEED)
np.random.seed(SEED)
rng = np.random.default_rng(SEED)
torch.manual_seed(SEED)
if torch.cuda.is_available():
torch.cuda.manual_seed_all(SEED)
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
device
device(type='cpu')
Prerequisites#
basic probability and expectation
comfort with vectors/matrices and simple neural nets
familiarity with PyTorch tensors and optimizers
1) Setup: MDPs and returns#
We assume an episodic Markov Decision Process (MDP):
states \(s \in \mathcal{S}\)
actions \(a \in \mathcal{A}\) (here: discrete)
transition kernel \(P(s' \mid s, a)\)
reward \(r_{t+1}\) after taking action \(a_t\) in state \(s_t\)
discount \(\gamma \in [0, 1)\)
The (discounted) return from time \(t\) is

\[
G_t = \sum_{k=0}^{\infty} \gamma^k \, r_{t+k+1}.
\]
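As a quick numeric check, the return can be computed directly for a hypothetical reward sequence:

```python
import numpy as np

# Hypothetical reward sequence r_{t+1}, r_{t+2}, ... and discount gamma
rewards = np.array([1.0, 0.0, 2.0, 1.0])
gamma = 0.9

# G_t = sum_k gamma^k * r_{t+k+1}
discounts = gamma ** np.arange(len(rewards))
G = float(np.dot(discounts, rewards))
print(G)  # 1 + 0 + 2*0.81 + 1*0.729 = 3.349
```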
2) Action-value functions#
For a policy \(\pi(a\mid s)\), the action-value function is

\[
Q^{\pi}(s, a) = \mathbb{E}_{\pi}\!\left[ G_t \,\middle|\, s_t = s,\ a_t = a \right].
\]

If actions are discrete, we can represent the action-values as a vector

\[
\big( Q(s, a_1), \dots, Q(s, a_{|\mathcal{A}|}) \big) \in \mathbb{R}^{|\mathcal{A}|},
\]

and pick greedy actions via

\[
a^{\text{greedy}}(s) = \arg\max_{a \in \mathcal{A}} Q(s, a).
\]
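With a vector of action-values, greedy selection is just an argmax (hypothetical values for three discrete actions):

```python
import numpy as np

q_values = np.array([0.2, 1.5, -0.3])  # hypothetical Q(s, ·) for 3 discrete actions
greedy_action = int(np.argmax(q_values))
print(greedy_action)  # 1
```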
3) Bellman equations (precise)#
3.1 Bellman expectation equation#
Under policy \(\pi\), the Bellman expectation equation for \(Q^{\pi}\) is

\[
Q^{\pi}(s, a) = \mathbb{E}\!\left[ r_{t+1} + \gamma \sum_{a'} \pi(a' \mid s_{t+1})\, Q^{\pi}(s_{t+1}, a') \,\middle|\, s_t = s,\ a_t = a \right].
\]
3.2 Bellman optimality equation#
Define the optimal action-value function

\[
Q^{*}(s, a) = \max_{\pi} Q^{\pi}(s, a).
\]

Then \(Q^*\) satisfies

\[
Q^{*}(s, a) = \mathbb{E}\!\left[ r_{t+1} + \gamma \max_{a'} Q^{*}(s_{t+1}, a') \,\middle|\, s_t = s,\ a_t = a \right].
\]
3.3 Tabular Q-learning update (intuition)#
Q-learning performs stochastic approximation to the Bellman optimality equation. Given a transition \((s_t, a_t, r_{t+1}, s_{t+1})\) and step size \(\alpha\):

\[
Q(s_t, a_t) \leftarrow Q(s_t, a_t) + \alpha \left[ r_{t+1} + \gamma \max_{a'} Q(s_{t+1}, a') - Q(s_t, a_t) \right].
\]
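A minimal tabular sketch of this update (a hypothetical 3-state chain, not the LineWorld environment defined later):

```python
import numpy as np

n_states, n_actions = 3, 2
Q = np.zeros((n_states, n_actions))
alpha, gamma = 0.5, 0.9

# One Q-learning update for the transition (s=0, a=1, r=1.0, s'=1)
s, a, r, s_next = 0, 1, 1.0, 1
td_target = r + gamma * Q[s_next].max()   # bootstrap from max over next actions
Q[s, a] += alpha * (td_target - Q[s, a])  # move Q(s,a) toward the TD target
print(Q[0, 1])  # 0.5 * (1.0 + 0.9*0 - 0) = 0.5
```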
4) From Q-learning to DQN (targets + loss)#
When \(\mathcal{S}\) is large or continuous, we approximate action-values with a neural network \(Q_\theta(s,a)\).
Given a transition \((s, a, r, s', d)\), where \(d \in \{0,1\}\) indicates a terminal transition (1 if the episode ended due to an environment terminal condition), DQN uses the TD target

\[
y = r + \gamma \,(1 - d)\, \max_{a'} Q_{\theta^-}(s', a'),
\]

where \(\theta^-\) are the parameters of a target network (a delayed copy of \(\theta\)).
We learn \(\theta\) by minimizing a TD regression loss over replayed transitions:

\[
L(\theta) = \mathbb{E}_{(s,a,r,s',d) \sim \mathcal{D}} \left[ \ell\big( Q_\theta(s, a),\ y \big) \right].
\]
Two common choices for \(\ell\):
squared loss: \(\ell(u,v) = (u-v)^2\)
Huber loss (Smooth L1), more robust to large TD errors:

\[
\ell(u, v) =
\begin{cases}
\tfrac{1}{2}(u - v)^2 & \text{if } |u - v| \le 1, \\
|u - v| - \tfrac{1}{2} & \text{otherwise.}
\end{cases}
\]
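To see the robustness difference, compare both losses on a small hypothetical batch containing one large TD error (an illustration only, not part of the training loop):

```python
import torch
import torch.nn.functional as F

q_sa = torch.tensor([0.1, 0.2, 5.0])    # predicted Q(s, a) for three transitions
target = torch.tensor([0.0, 0.3, 0.0])  # TD targets; the last error is an outlier

mse = F.mse_loss(q_sa, target)
huber = F.smooth_l1_loss(q_sa, target)  # quadratic near 0, linear for |error| > 1
print(mse.item(), huber.item())  # the Huber loss is much smaller: the outlier contributes linearly
```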
5) Experience replay (precisely)#
DQN is off-policy: we generate experience using a behavior policy (typically \(\epsilon\)-greedy) while learning the greedy \(Q\)-function.
We store transitions in a replay buffer:

\[
\mathcal{D} = \{ (s_i, a_i, r_i, s'_i, d_i) \}_{i=1}^{N}.
\]

At each gradient step we sample an (approximately) i.i.d. mini-batch uniformly:

\[
\{ (s_j, a_j, r_j, s'_j, d_j) \}_{j=1}^{B} \sim \mathrm{Uniform}(\mathcal{D}).
\]
Why this helps:
Temporal decorrelation: online trajectories yield highly correlated samples; replay makes SGD closer to its i.i.d. assumptions.
Data reuse: each transition can be used for many gradient steps, improving sample-efficiency.
Stabilization: mixing older and newer experience reduces non-stationarity of the training distribution.
We use uniform replay here for clarity (prioritized replay is a common extension).
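A minimal sketch of uniform replay using a `collections.deque` (the NumPy-backed buffer actually used in this notebook appears in section 7):

```python
import random
from collections import deque

buffer = deque(maxlen=1000)  # oldest transitions are evicted automatically

# Store some hypothetical transitions (s, a, r, s', d)
for t in range(100):
    buffer.append((t, t % 2, 0.0, t + 1, False))

batch = random.sample(buffer, k=8)  # uniform sampling without replacement
print(len(batch))  # 8
```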
6) Target network updates#
A moving target \(y\) can destabilize training. DQN stabilizes learning by using a target network \(Q_{\theta^-}\).
Two standard update schemes:
hard update every \(C\) steps: \(\theta^- \leftarrow \theta\)

soft / Polyak update: \(\theta^- \leftarrow \tau\,\theta + (1-\tau)\,\theta^-\)
In the implementation below, setting \(\tau=1\) with periodic updates is a hard update.
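The two schemes differ only in \(\tau\). On plain tensors (hypothetical parameter values, mirroring what polyak_update does per parameter):

```python
import torch

theta = torch.tensor([1.0, 2.0])        # online parameters
theta_minus = torch.tensor([0.0, 0.0])  # target parameters

tau = 0.005  # soft / Polyak update: target drifts slowly toward the online net
soft = tau * theta + (1 - tau) * theta_minus
print(soft)  # tensor([0.0050, 0.0100])

tau = 1.0    # hard update: full copy of the online parameters
hard = tau * theta + (1 - tau) * theta_minus
print(hard)  # tensor([1., 2.])
```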
7) Low-level PyTorch implementation (from scratch)#
We implement:
a tiny environment (so the notebook runs without external RL dependencies)
replay buffer with uniform sampling
an MLP \(Q\)-network
the DQN loop with \(\epsilon\) scheduling + target network updates
@dataclass(frozen=True)
class DQNConfig:
gamma: float = 0.99
learning_rate: float = 1e-3
buffer_size: int = 50_000
batch_size: int = 64
learning_starts: int = 1_000
train_freq: int = 1
gradient_steps: int = 1
target_update_interval: int = 500
tau: float = 1.0
eps_start: float = 1.0
eps_end: float = 0.05
eps_fraction: float = 0.8
max_grad_norm: float = 10.0
hidden_sizes: tuple = (128, 128)
def linear_schedule(start: float, end: float, duration: int, t: int) -> float:
if duration <= 0:
return end
frac = min(max(t / float(duration), 0.0), 1.0)
return start + frac * (end - start)
def moving_average(x, window: int):
x = np.asarray(x, dtype=np.float64)
if window <= 1:
return x
if x.size < window:
return np.full_like(x, np.nan, dtype=np.float64)
kernel = np.ones(int(window), dtype=np.float64) / float(window)
ma = np.convolve(x, kernel, mode="valid")
return np.concatenate([np.full(window - 1, np.nan), ma])
def reset_env(env, seed=None):
out = env.reset(seed=int(seed)) if seed is not None else env.reset()
if isinstance(out, tuple) and len(out) == 2:
obs, info = out
else:
obs, info = out, {}
return obs, info
def step_env(env, action: int):
out = env.step(int(action))
if isinstance(out, tuple) and len(out) == 5:
obs, reward, terminated, truncated, info = out
return obs, float(reward), bool(terminated), bool(truncated), info
if isinstance(out, tuple) and len(out) == 4:
obs, reward, done, info = out
return obs, float(reward), bool(done), False, info
raise ValueError(f"Unexpected step() return of length {len(out)}")
@torch.no_grad()
def polyak_update(target_net: nn.Module, online_net: nn.Module, tau: float):
for p_targ, p in zip(target_net.parameters(), online_net.parameters()):
p_targ.data.mul_(1.0 - tau)
p_targ.data.add_(p.data, alpha=tau)
def infer_n_actions(env) -> int:
if hasattr(env, "action_space") and hasattr(env.action_space, "n"):
return int(env.action_space.n)
if hasattr(env, "n_actions"):
return int(env.n_actions)
raise ValueError("Cannot infer number of actions: expected env.action_space.n or env.n_actions")
def select_action(q_net: nn.Module, obs: np.ndarray, epsilon: float, n_actions: int, rng: np.random.Generator) -> int:
if rng.random() < epsilon:
return int(rng.integers(0, n_actions))
obs_t = torch.as_tensor(obs, device=device, dtype=torch.float32).unsqueeze(0)
with torch.no_grad():
q = q_net(obs_t)
return int(torch.argmax(q, dim=1).item())
class LineWorldEnv:
"""A tiny 1D environment with discrete actions.
- states are positions {0, 1, ..., n_states-1}
- actions: 0=left, 1=right
- observation is a one-hot vector in R^{n_states}
- reward: step_penalty each step, +goal_reward when reaching the goal
"""
def __init__(
self,
n_states: int = 15,
max_steps: int = 40,
step_penalty: float = -0.01,
goal_reward: float = 1.0,
slip_prob: float = 0.0,
):
if n_states < 2:
raise ValueError("n_states must be >= 2")
if max_steps < 1:
raise ValueError("max_steps must be >= 1")
if not (0.0 <= slip_prob <= 1.0):
raise ValueError("slip_prob must be in [0, 1]")
self.n_states = int(n_states)
self.n_actions = 2
self.max_steps = int(max_steps)
self.step_penalty = float(step_penalty)
self.goal_reward = float(goal_reward)
self.slip_prob = float(slip_prob)
self._pos = 0
self._t = 0
self._rng = np.random.default_rng(0)
def _obs(self) -> np.ndarray:
obs = np.zeros((self.n_states,), dtype=np.float32)
obs[self._pos] = 1.0
return obs
def reset(self, seed=None, options=None):
if seed is not None:
self._rng = np.random.default_rng(int(seed))
self._pos = 0
self._t = 0
return self._obs(), {}
def step(self, action: int):
action = int(action)
if self.slip_prob > 0.0 and self._rng.random() < self.slip_prob:
action = 1 - action # slip: flip action
if action == 0:
self._pos = max(0, self._pos - 1)
elif action == 1:
self._pos = min(self.n_states - 1, self._pos + 1)
else:
raise ValueError("action must be 0 (left) or 1 (right)")
self._t += 1
terminated = self._pos == (self.n_states - 1)
truncated = self._t >= self.max_steps
reward = self.step_penalty + (self.goal_reward if terminated else 0.0)
info = {"pos": self._pos}
return self._obs(), float(reward), bool(terminated), bool(truncated), info
class ReplayBuffer:
def __init__(self, capacity: int, obs_shape, device: torch.device):
self.capacity = int(capacity)
self.device = device
self.obs_buf = np.zeros((self.capacity, *obs_shape), dtype=np.float32)
self.next_obs_buf = np.zeros((self.capacity, *obs_shape), dtype=np.float32)
self.actions_buf = np.zeros((self.capacity,), dtype=np.int64)
self.rewards_buf = np.zeros((self.capacity,), dtype=np.float32)
self.dones_buf = np.zeros((self.capacity,), dtype=np.float32)
self.ptr = 0
self.size = 0
def add(self, obs, action: int, reward: float, next_obs, done: float):
self.obs_buf[self.ptr] = obs
self.next_obs_buf[self.ptr] = next_obs
self.actions_buf[self.ptr] = int(action)
self.rewards_buf[self.ptr] = float(reward)
self.dones_buf[self.ptr] = float(done)
self.ptr = (self.ptr + 1) % self.capacity
self.size = min(self.size + 1, self.capacity)
def sample(self, batch_size: int, rng: np.random.Generator):
if self.size < batch_size:
raise ValueError("Not enough samples in buffer")
idx = rng.integers(0, self.size, size=int(batch_size))
obs = torch.as_tensor(self.obs_buf[idx], device=self.device, dtype=torch.float32)
next_obs = torch.as_tensor(self.next_obs_buf[idx], device=self.device, dtype=torch.float32)
actions = torch.as_tensor(self.actions_buf[idx], device=self.device, dtype=torch.int64)
rewards = torch.as_tensor(self.rewards_buf[idx], device=self.device, dtype=torch.float32)
dones = torch.as_tensor(self.dones_buf[idx], device=self.device, dtype=torch.float32)
return obs, actions, rewards, next_obs, dones
class QNetwork(nn.Module):
def __init__(self, obs_dim: int, n_actions: int, hidden_sizes=(128, 128)):
super().__init__()
layers = []
in_dim = int(obs_dim)
for h in hidden_sizes:
layers.append(nn.Linear(in_dim, int(h)))
layers.append(nn.ReLU())
in_dim = int(h)
layers.append(nn.Linear(in_dim, int(n_actions)))
self.net = nn.Sequential(*layers)
def forward(self, x: torch.Tensor) -> torch.Tensor:
return self.net(x)
def dqn_update(
q_net: nn.Module,
target_net: nn.Module,
optimizer: torch.optim.Optimizer,
batch,
gamma: float,
max_grad_norm: float,
) -> float:
obs, actions, rewards, next_obs, dones = batch
q_values = q_net(obs) # (B, A)
q_sa = q_values.gather(1, actions.unsqueeze(1)).squeeze(1) # (B,)
with torch.no_grad():
next_q = target_net(next_obs)
max_next_q = next_q.max(dim=1).values
target = rewards + gamma * (1.0 - dones) * max_next_q
loss = F.smooth_l1_loss(q_sa, target)
optimizer.zero_grad(set_to_none=True)
loss.backward()
nn.utils.clip_grad_norm_(q_net.parameters(), max_norm=max_grad_norm)
optimizer.step()
return float(loss.item())
def train_dqn(env, config: DQNConfig, num_episodes: int = 300, log_every: int = 50):
n_actions = infer_n_actions(env)
obs0, _ = reset_env(env, seed=int(rng.integers(0, 1_000_000)))
obs0 = np.asarray(obs0, dtype=np.float32)
if obs0.ndim != 1:
raise ValueError("This minimal notebook assumes 1D vector observations")
obs_dim = int(obs0.shape[0])
obs_shape = obs0.shape
max_steps_per_episode = int(getattr(env, "max_steps", 200))
eps_decay_steps = max(1, int(config.eps_fraction * num_episodes * max_steps_per_episode))
q_net = QNetwork(obs_dim, n_actions, hidden_sizes=config.hidden_sizes).to(device)
target_net = QNetwork(obs_dim, n_actions, hidden_sizes=config.hidden_sizes).to(device)
target_net.load_state_dict(q_net.state_dict())
target_net.eval()
optimizer = torch.optim.Adam(q_net.parameters(), lr=config.learning_rate)
buffer = ReplayBuffer(config.buffer_size, obs_shape=obs_shape, device=device)
global_step = 0
episode_rewards = []
episode_lengths = []
epsilons = []
loss_steps = []
loss_values = []
q_probe_history = []
probe_obs = obs0.copy()
for ep in range(int(num_episodes)):
obs, _ = reset_env(env, seed=int(rng.integers(0, 1_000_000)))
obs = np.asarray(obs, dtype=np.float32)
total_reward = 0.0
steps = 0
terminated = False
truncated = False
while not (terminated or truncated):
epsilon = linear_schedule(config.eps_start, config.eps_end, eps_decay_steps, global_step)
action = select_action(q_net, obs, epsilon, n_actions, rng)
next_obs, reward, terminated, truncated, _ = step_env(env, action)
next_obs = np.asarray(next_obs, dtype=np.float32)
done_for_bootstrap = float(terminated) # time-limit truncation should still bootstrap
buffer.add(obs, action, reward, next_obs, done_for_bootstrap)
obs = next_obs
total_reward += float(reward)
steps += 1
global_step += 1
if (
buffer.size >= config.batch_size
and global_step >= config.learning_starts
and (global_step % config.train_freq == 0)
):
for _ in range(config.gradient_steps):
batch = buffer.sample(config.batch_size, rng)
loss = dqn_update(
q_net=q_net,
target_net=target_net,
optimizer=optimizer,
batch=batch,
gamma=config.gamma,
max_grad_norm=config.max_grad_norm,
)
loss_steps.append(global_step)
loss_values.append(loss)
if global_step % config.target_update_interval == 0:
polyak_update(target_net, q_net, tau=config.tau)
if steps >= max_steps_per_episode:
truncated = True
episode_rewards.append(total_reward)
episode_lengths.append(steps)
epsilons.append(epsilon)
with torch.no_grad():
q_probe = q_net(torch.as_tensor(probe_obs, device=device, dtype=torch.float32).unsqueeze(0))
q_probe_history.append(q_probe.squeeze(0).cpu().numpy())
if log_every and (ep + 1) % int(log_every) == 0:
print(
f"Episode {ep+1:4d} | reward {total_reward:7.2f} | "
f"eps {epsilon:5.3f} | buffer {buffer.size:6d} | steps {global_step:7d}"
)
logs = {
"episode_rewards": np.asarray(episode_rewards, dtype=np.float64),
"episode_lengths": np.asarray(episode_lengths, dtype=np.int64),
"epsilons": np.asarray(epsilons, dtype=np.float64),
"loss_steps": np.asarray(loss_steps, dtype=np.int64),
"loss_values": np.asarray(loss_values, dtype=np.float64),
"q_probe": np.asarray(q_probe_history, dtype=np.float64),
"n_actions": n_actions,
"obs_dim": obs_dim,
}
return q_net, logs
env = LineWorldEnv(n_states=15, max_steps=40, step_penalty=-0.01, goal_reward=1.0, slip_prob=0.05)
config = DQNConfig(
gamma=0.99,
learning_rate=1e-3,
buffer_size=25_000,
batch_size=64,
learning_starts=500,
train_freq=1,
gradient_steps=1,
target_update_interval=200,
tau=1.0,
eps_start=1.0,
eps_end=0.05,
eps_fraction=0.8,
max_grad_norm=10.0,
hidden_sizes=(128, 128),
)
q_net, logs = train_dqn(env, config=config, num_episodes=300, log_every=50)
logs.keys()
Episode 50 | reward -0.40 | eps 0.802 | buffer 1999 | steps 1999
Episode 100 | reward -0.40 | eps 0.612 | buffer 3917 | steps 3917
Episode 150 | reward 0.69 | eps 0.465 | buffer 5403 | steps 5403
Episode 200 | reward 0.85 | eps 0.338 | buffer 6690 | steps 6690
Episode 250 | reward 0.78 | eps 0.225 | buffer 7833 | steps 7833
Episode 300 | reward 0.74 | eps 0.130 | buffer 8792 | steps 8792
dict_keys(['episode_rewards', 'episode_lengths', 'epsilons', 'loss_steps', 'loss_values', 'q_probe', 'n_actions', 'obs_dim'])
# Reward per episode (learning curve)
rewards = logs["episode_rewards"]
ma = moving_average(rewards, window=20)
fig = go.Figure()
fig.add_trace(go.Scatter(y=rewards, mode="lines", name="reward/episode"))
fig.add_trace(go.Scatter(y=ma, mode="lines", name="moving avg (20)", line=dict(width=3)))
fig.update_layout(
title="DQN learning curve (reward per episode)",
xaxis_title="episode",
yaxis_title="total reward",
)
fig.show()
# TD loss over training steps
loss_steps = logs["loss_steps"]
loss_values = logs["loss_values"]
if loss_values.size == 0:
print("No loss values recorded (try lowering learning_starts or increasing episodes).")
else:
fig = px.line(
x=loss_steps,
y=loss_values,
labels={"x": "environment step", "y": "Huber TD loss"},
title="DQN TD loss (Smooth L1)",
)
fig.show()
# Exploration schedule ($\epsilon$) over episodes
eps = logs["epsilons"]
fig = px.line(
x=np.arange(len(eps)),
y=eps,
labels={"x": "episode", "y": "epsilon"},
title="Epsilon schedule",
)
fig.show()
# Q-values over training for a fixed probe state (the start state)
q_probe = logs["q_probe"] # (episodes, n_actions)
n_actions = logs["n_actions"]
fig = go.Figure()
for a in range(n_actions):
fig.add_trace(go.Scatter(y=q_probe[:, a], mode="lines", name=f"Q(probe, a={a})"))
fig.update_layout(
title="Learned Q-values over training (probe state)",
xaxis_title="episode",
yaxis_title="Q-value",
)
fig.show()
# Heatmap of Q(s,a) for all LineWorld states (since observations are one-hot)
states = np.eye(env.n_states, dtype=np.float32)
with torch.no_grad():
q_all = q_net(torch.as_tensor(states, device=device, dtype=torch.float32)).cpu().numpy()
fig = px.imshow(
q_all,
aspect="auto",
origin="lower",
labels={"x": "action", "y": "state", "color": "Q(s,a)"},
title="Q-values heatmap across states and actions",
color_continuous_scale="RdBu",
)
fig.update_xaxes(tickmode="array", tickvals=list(range(env.n_actions)), ticktext=["left", "right"])
fig.show()
Pitfalls + diagnostics#
Diverging loss / exploding Q-values: try a smaller learning rate, larger replay buffer, or gradient clipping.
No learning: ensure enough exploration (\(\epsilon\) schedule), lower `learning_starts`, or increase training episodes.
Time-limit truncation: if an episode ends due to a time limit (not a true terminal state), bootstrapping should still happen. In the code above we store \(d=1\) only when `terminated=True`.
Target update too frequent/rare: too frequent makes targets chase the online net; too rare slows learning.
Reward scale: large rewards can increase TD errors; Huber loss helps.
(Optional) Run the same DQN on Gymnasium environments#
If you have gymnasium installed, you can plug it into the same training code as long as:
the action space is discrete (\(|\mathcal{A}| < \infty\))
the observation can be represented as a 1D float vector (this notebook uses an MLP)
Example (CartPole):
import gymnasium as gym
env = gym.make("CartPole-v1")
q_net, logs = train_dqn(env, config=config, num_episodes=400, log_every=50)
If gymnasium is missing, install with:
pip install gymnasium
Stable-Baselines3 DQN (reference implementation)#
Stable-Baselines3 includes a DQN implementation for discrete action spaces. Example from the SB3 docs:
import gymnasium as gym
from stable_baselines3 import DQN
env = gym.make("CartPole-v1")
model = DQN("MlpPolicy", env, verbose=1)
model.learn(total_timesteps=10_000, log_interval=4)
model.save("dqn_cartpole")
Docs:
https://stable-baselines3.readthedocs.io/en/master/modules/dqn.html
https://github.com/DLR-RM/stable-baselines3
Original DQN paper:
Mnih et al. (2015), Human-level control through deep reinforcement learning
Stable-Baselines3 DQN(...) hyperparameters (explained)#
Below are the most important SB3 DQN constructor hyperparameters and what they control:
`policy`: network/policy type (e.g., "MlpPolicy", "CnnPolicy", "MultiInputPolicy").
`learning_rate`: optimizer step size (a float or a schedule).
`buffer_size`: replay buffer capacity (number of transitions stored).
`learning_starts`: number of environment steps collected before training begins.
`batch_size`: mini-batch size sampled from the replay buffer.
`gamma`: discount factor \(\gamma\).
`train_freq`: how often to run training updates while collecting data (e.g., every 1 or 4 environment steps).
`gradient_steps`: how many gradient steps to take per training iteration.
`target_update_interval`: how often to update the target network (in environment steps).
`tau`: Polyak coefficient for target updates (\(\tau=1\) corresponds to a hard copy).
`exploration_initial_eps`, `exploration_final_eps`, `exploration_fraction`: linear \(\epsilon\)-greedy exploration schedule.
`max_grad_norm`: gradient clipping threshold.
`policy_kwargs`: network architecture and optimizer details (e.g., `net_arch`, activation, optimizer settings).
`replay_buffer_class`, `replay_buffer_kwargs`: swap or parameterize the replay buffer (e.g., HER buffers for goal-conditioned tasks).
`optimize_memory_usage`: enables a more memory-efficient replay buffer variant (useful with large observations).
`device`: where to run the model ("cpu" or "cuda").
`seed`, `verbose`, `tensorboard_log`: reproducibility and logging controls.